
Conversation


@cboss6 cboss6 commented Sep 27, 2025

Description:
PR-23745 introduced the round-robin expert placement strategy for MoE models with multiple expert groups, providing a simple yet effective way to distribute experts evenly across devices.
This PR extends that work by ensuring full compatibility with EPLB (Expert Parallel Load Balancing). With this enhancement, round-robin placement can now be seamlessly combined with dynamic expert load balancing, enabling more flexible expert scheduling while maintaining balanced utilization and performance.
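For illustration, here is a minimal sketch of the difference between the two placement strategies (the function names and signature below are hypothetical, not the exact vLLM helpers):

```python
# Hypothetical sketch: map global expert IDs to expert-parallel (EP) ranks.
# Assumes num_experts is divisible by ep_size for simplicity.

def linear_placement(num_experts: int, ep_size: int) -> list[list[int]]:
    """Contiguous chunks: rank r holds experts [r*k, (r+1)*k)."""
    k = num_experts // ep_size
    return [list(range(r * k, (r + 1) * k)) for r in range(ep_size)]

def round_robin_placement(num_experts: int, ep_size: int) -> list[list[int]]:
    """Strided assignment: expert e goes to rank e % ep_size."""
    return [list(range(r, num_experts, ep_size)) for r in range(ep_size)]

# Example with 8 experts over 4 EP ranks:
#   linear:      [[0, 1], [2, 3], [4, 5], [6, 7]]
#   round-robin: [[0, 4], [1, 5], [2, 6], [3, 7]]
```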

Performance

Conclusion: With the configuration listed below and EPLB enabled, the round-robin strategy improves average throughput and end-to-end latency by approximately 3% over the default linear strategy.

Test Platform:
vLLM version: vllm/vllm-openai:nightly-8c546102658f97b10d13bcf25193b65edc6ea6ff
Model: DeepSeek-V2-Chat-0628
GPU: 8 × H20
Serving config:
python3 -u -m vllm.entrypoints.openai.api_server \
    --model ${MODEL_PATH} \
    --trust-remote-code \
    --gpu-memory-utilization 0.85 \
    -tp 8 \
    --enable-expert-parallel \
    --enable-eplb \
    --expert-placement-strategy "round_robin"

Benchmark config: input_len=1024, output_len=128, request_rate=4, max_concurrency=4, num_prompts=32:
python3 ./bench_serving.py \
    --backend vllm \
    --dataset-name random \
    --model ${MODEL_PATH} \
    --random-input-len 1024 \
    --random-output-len 128 \
    --random-range-ratio 0.5 \
    --tokenizer ./tokenizer \
    --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --request-rate 4 \
    --max-concurrency 4 \
    --num-prompts 32 \
    --base-url http://127.0.0.1:8000 \
    --port 8000

(Screenshot: benchmark results.)

Accuracy Test

Tested with DeepSeek-V2-Chat-0628 on 8 × H20 with the following serving command:

python3 -u -m vllm.entrypoints.openai.api_server \
    --model ${model_path} \
    --trust-remote-code \
    --gpu-memory-utilization 0.85 \
    -tp 8 \
    --enable-expert-parallel \
    --enable-eplb \
    --expert-placement-strategy "round_robin"

Note: DeepSeek-V2 performs poorly on the chosen datasets; these results are only meant to verify that this PR has no impact on accuracy.

| Dataset | vLLM v0.10.1.1 | This PR |
| --- | --- | --- |
| AIME24 | 13.33% | 20.00% |
| GPQA | 41.91% | 44.44% |
| MATH500 | 72.20% | 72.40% |
(Screenshot: accuracy results.)

@cboss6 cboss6 requested a review from mgoin as a code owner September 27, 2025 04:31

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables the round-robin expert placement strategy for MoE models with EPLB enabled. The changes involve refactoring the expert placement strategy logic into a utility function and updating the EPLB state creation to support the round-robin strategy. The refactoring improves code organization. However, I've found a critical bug in the implementation of the round-robin placement logic that occurs when the number of experts is not divisible by the number of expert parallel ranks. This can lead to incorrect model behavior. A fix is suggested to ensure correctness.
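For context, a small illustration of the corner case flagged above (numbers are illustrative; this is not the suggested fix itself): when the expert count is not divisible by the EP size, a plain stride-by-`ep_size` mapping leaves ranks with unequal expert counts, so any bookkeeping that assumes an equal per-rank count can index the wrong experts.

```python
# Illustrative only: 10 logical experts across 4 EP ranks.
num_experts, ep_size = 10, 4

naive = [list(range(r, num_experts, ep_size)) for r in range(ep_size)]
print(naive)                     # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
print([len(r) for r in naive])   # [3, 3, 2, 2] -- per-rank counts differ
```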

@cboss6 cboss6 changed the title [Feat][EPLB] Enable Round-robin expert placement strategy with eplb enabled. [Feat][EPLB][Perf] Enable Round-robin expert placement strategy with eplb enabled. Sep 27, 2025
@cboss6 cboss6 changed the title [Feat][EPLB][Perf] Enable Round-robin expert placement strategy with eplb enabled. [Feat][EPLB][Perf] Enable Round-robin expert placement strategy while eplb is enabled. Sep 27, 2025
@cboss6 cboss6 requested a review from 22quinn as a code owner September 28, 2025 07:28

abmfy commented Sep 30, 2025

Thanks for the contribution!
@tlrmchlsmth Do you think we should consider using this strategy as the default when EPLB is enabled? My concern is that since this placement only affects the stage before the first EPLB rearrangement, it won’t actually bring any improvement. I’d prefer to keep the code simple.


cboss6 commented Sep 30, 2025

I don’t agree with the view that round-robin offers no improvement. Since round-robin mainly serves as a better initialization before EPLB adjusts the placement, my performance tests focused on the early stage (after at least one EPLB rearrangement), where average throughput and end-to-end latency improved by around 2.5–3%.

If we only consider the initial state, as far as I know, round-robin placement is generally better than linear placement for multi–expert-group models without redundant experts. Furthermore, since this PR carries no risk when EPLB is enabled, I believe it’s reasonable to make round-robin the default initial placement instead of linear.
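To make the argument concrete, a toy example (numbers chosen for illustration, not taken from the PR): with two expert groups of four experts each on four EP ranks, linear placement concentrates each group on half of the ranks, while round-robin spreads every group across all ranks, so load that skews toward one group is not concentrated on a subset of ranks before the first EPLB rearrangement.

```python
# Illustrative only: 8 experts, experts 0-3 in group A and 4-7 in group B,
# placed on 4 EP ranks.
num_experts, ep_size = 8, 4
group = lambda e: "A" if e < 4 else "B"

linear = [[2 * r, 2 * r + 1] for r in range(ep_size)]
round_robin = [list(range(r, num_experts, ep_size)) for r in range(ep_size)]

print([[group(e) for e in rank] for rank in linear])       # [['A','A'], ['A','A'], ['B','B'], ['B','B']]
print([[group(e) for e in rank] for rank in round_robin])  # [['A','B'], ['A','B'], ['A','B'], ['A','B']]
```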
